{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## algorithm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def strassen(A, B):\n",
    "    k = A.shape[0] // 2\n",
    "    if k == 0:\n",
    "        return A * B\n",
    "\n",
    "    A11, A12 = A[:k, :k], A[:k, k:]\n",
    "    A21, A22 = A[k:, :k], A[k:, k:]\n",
    "    B11, B12 = B[:k, :k], B[:k, k:]\n",
    "    B21, B22 = B[k:, :k], B[k:, k:]\n",
    "    \n",
    "    T1 = strassen(A11 + A22, B11 + B22)\n",
    "    T2 = strassen(A21 + A22, B11)\n",
    "    T3 = strassen(A11, B12 - B22)\n",
    "    T4 = strassen(A22, B21 - B11)\n",
    "    T5 = strassen(A11 + A12, B22)\n",
    "    T6 = strassen(A21 - A11, B11 + B12)\n",
    "    T7 = strassen(A12 - A22, B21 + B22)\n",
    "    \n",
    "    C = np.zeros(A.shape, dtype=A.dtype)\n",
    "    C[:k, :k] = T1 + T4 - T5 + T7\n",
    "    C[:k, k:] = T3 + T5\n",
    "    C[k:, :k] = T2 + T4\n",
    "    C[k:, k:] = T1 - T2 + T3 + T6\n",
    "    \n",
    "    return C"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## run"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[9, 1, 1, 2, 5, 2, 8, 3],\n",
       "       [4, 7, 1, 9, 6, 6, 9, 5],\n",
       "       [0, 1, 0, 1, 7, 6, 0, 8],\n",
       "       [9, 0, 4, 3, 3, 4, 4, 5],\n",
       "       [6, 9, 4, 7, 1, 5, 5, 0],\n",
       "       [3, 5, 8, 1, 6, 2, 3, 8],\n",
       "       [3, 1, 0, 4, 2, 5, 1, 7],\n",
       "       [2, 6, 4, 6, 3, 4, 6, 3]])"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X = np.random.randint(0, 10, (8, 8))\n",
    "X"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[8, 1, 1, 7, 6, 2, 7, 3],\n",
       "       [2, 3, 6, 5, 8, 2, 0, 0],\n",
       "       [5, 5, 1, 6, 3, 4, 4, 9],\n",
       "       [4, 0, 6, 0, 4, 0, 7, 8],\n",
       "       [2, 7, 0, 7, 3, 1, 5, 2],\n",
       "       [8, 1, 9, 3, 3, 7, 6, 2],\n",
       "       [8, 5, 7, 4, 7, 8, 9, 7],\n",
       "       [4, 6, 2, 9, 2, 8, 0, 7]])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Y = np.random.randint(0, 10, (8, 8))\n",
    "Y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[189, 112, 108, 174, 156, 131, 190, 143],\n",
       "       [239, 153, 228, 210, 228, 186, 242, 215],\n",
       "       [100, 106,  82, 144,  67, 115,  78,  90],\n",
       "       [194, 104, 105, 181, 137, 137, 175, 164],\n",
       "       [196,  90, 186, 153, 201, 122, 187, 157],\n",
       "       [162, 165, 102, 226, 147, 156, 129, 182],\n",
       "       [122,  72,  99, 122,  84, 109,  98, 111],\n",
       "       [170, 113, 162, 152, 165, 135, 165, 167]])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "strassen(X, Y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
